import argparse
import json
from pathlib import Path

from rouge import Rouge
from tqdm import tqdm

try:
    from nltk.translate.bleu_score import corpus_bleu
except ImportError as e:
    raise ImportError(
        "Please install nltk to run this script: pip install nltk && python -m nltk.downloader punkt"
    ) from e


def get_metrics(ground_truth, generated_text, tokenize=str.split):
    """Compute corpus-level BLEU-1..4 and average ROUGE-L F1 for paired texts.

    Only keys present in ``ground_truth`` are scored; a key missing from
    ``generated_text`` is treated as an empty prediction. Pairs where either
    side is empty or whitespace-only are dropped before scoring.

    Args:
        ground_truth: Mapping of example id -> reference string.
        generated_text: Mapping of example id -> predicted string.
        tokenize: Callable turning a string into a list of tokens for BLEU.
            Defaults to whitespace splitting.

    Returns:
        Tuple ``(bleu1, bleu2, bleu3, bleu4, rouge_l)`` of floats, or all
        zeros when no scorable pair remains.
    """
    # Keep (reference, hypothesis) pairs in ground_truth's key order,
    # skipping any pair with a blank side.
    pairs = []
    for key, ref in ground_truth.items():
        hyp = generated_text.get(key, "")
        if hyp.strip() and ref.strip():
            pairs.append((ref, hyp))

    if not pairs:
        return 0.0, 0.0, 0.0, 0.0, 0.0

    ref_texts = [ref for ref, _ in pairs]
    hyp_texts = [hyp for _, hyp in pairs]

    # corpus_bleu expects token sequences: for each hypothesis a list of
    # reference token lists, and for each hypothesis a token list. Passing raw
    # strings makes NLTK iterate over characters and score character n-grams.
    references = [[tokenize(ref)] for ref in ref_texts]
    hypotheses = [tokenize(hyp) for hyp in hyp_texts]

    bleu_1_score = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
    bleu_2_score = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
    # BLEU weights must sum to 1; (0.33, 0.33, 0.33) sums to 0.99.
    bleu_3_score = corpus_bleu(
        references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0)
    )
    bleu_4_score = corpus_bleu(
        references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25)
    )

    # The rouge package takes raw strings and tokenizes them itself.
    scores = Rouge().get_scores(hyp_texts, ref_texts, avg=True)
    rouge_l_score = scores["rouge-l"]["f"]

    return bleu_1_score, bleu_2_score, bleu_3_score, bleu_4_score, rouge_l_score


def _extract_pairs(records, desc):
    """Pull (reference, prediction) string pairs out of a list of JSON records.

    Each record is expected to be a dict with a ``ground_truth`` string and a
    prediction string under ``moderator_tweet`` (used whenever that key is
    present) or otherwise ``generated``. ``<hyperlink>`` placeholders are
    removed from the reference. Records whose reference or prediction is
    blank after stripping are dropped silently. Records that are not dicts,
    or whose fields are not strings, are dropped and counted as malformed.

    Args:
        records: Sequence of parsed JSON records.
        desc: Label for the progress bar.

    Returns:
        ``(ground_truth, generated_text, skipped)``: two dicts keyed by
        ``"question<N>"`` (N = 1-based record position) holding stripped
        strings, and the number of malformed records skipped.
    """
    ground_truth = {}
    generated_text = {}
    skipped = 0
    for idx, record in enumerate(tqdm(records, desc=desc, leave=False), start=1):
        if not isinstance(record, dict):
            skipped += 1
            continue
        gt = record.get("ground_truth", "")
        # Use 'moderator_tweet' if present; otherwise fall back to 'generated'.
        pred = record.get("moderator_tweet", record.get("generated", ""))
        if not isinstance(gt, str) or not isinstance(pred, str):
            skipped += 1
            continue
        # Remove <hyperlink> occurrences from ground truth before scoring.
        gt = gt.replace("<hyperlink>", "").strip()
        pred = pred.strip()
        if gt and pred:
            key = f"question{idx}"
            ground_truth[key] = gt
            generated_text[key] = pred
    return ground_truth, generated_text, skipped


def main():
    """Score every ``*.json`` file in a directory and write a TSV summary.

    Each JSON file must hold a top-level list of records. Files that fail to
    decode or parse, or whose top level is not a list, are reported and
    skipped. One tab-separated row of BLEU-1..4 and ROUGE-L is written per
    scored file, after a header row.

    Raises:
        FileNotFoundError: If the input directory does not exist or contains
            no JSON files.
    """
    parser = argparse.ArgumentParser(
        description="Compute corpus BLEU (1-4) and ROUGE-L scores from JSON predictions."
    )
    parser.add_argument(
        "input_dir", type=str, help="Directory containing JSON prediction files."
    )
    parser.add_argument(
        "--output",
        type=str,
        default="gpt_old_metrics.txt",
        help="Output .txt file to write metrics.",
    )
    args = parser.parse_args()

    input_path = Path(args.input_dir)
    if not input_path.is_dir():
        raise FileNotFoundError(
            f"Input directory {input_path} does not exist or is not a directory."
        )

    json_files = sorted(input_path.glob("*.json"))
    if not json_files:
        raise FileNotFoundError(f"No JSON files found in {input_path}.")

    lines = ["Filename\tBLEU-1\tBLEU-2\tBLEU-3\tBLEU-4\tROUGE-L\n"]

    for jf in json_files:
        with open(jf, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                print(f"Failed to parse {jf}: {e}")
                continue

        # A non-list top level would otherwise yield a misleading all-zero row
        # (dict) or crash on len() (number).
        if not isinstance(data, list):
            print(
                f"Skipping {jf}: expected a top-level JSON list, "
                f"got {type(data).__name__}"
            )
            continue

        ground_truth, generated_text, skipped = _extract_pairs(data, jf.name)
        if skipped:
            print(f"{jf.name}: skipped {skipped} malformed record(s)")

        bleu1, bleu2, bleu3, bleu4, rouge_l = get_metrics(
            ground_truth, generated_text
        )
        lines.append(
            f"{jf.name}\t{bleu1:.4f}\t{bleu2:.4f}\t{bleu3:.4f}\t{bleu4:.4f}\t{rouge_l:.4f}\n"
        )

    with open(args.output, "w", encoding="utf-8") as f:
        f.writelines(lines)

    # Count scored rows (excluding the header), not input files, since some
    # files may have been skipped above.
    print(f"Metrics for {len(lines) - 1} file(s) written to {args.output}")


if __name__ == "__main__":
    main()
